package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/format"
	"io/ioutil"
	"log"
	"os"
	"regexp"
	"strings"
	"unicode"
)

// main runs the generator and terminates the process with a non-zero exit
// status if it fails. Log output is prefixed with "genmark: " and carries no
// timestamp. When the failure is a *flagError, the flag usage text is printed
// before the error message.
func main() {
	log.SetPrefix("genmark: ")
	log.SetFlags(0)

	err := run()
	if err == nil {
		return
	}
	if _, isFlagErr := err.(*flagError); isFlagErr {
		flag.Usage()
	}
	log.Fatal(err)
}

// flagError reports invalid command-line flag usage. main recognizes this
// type and prints the flag usage text before exiting.
type flagError struct {
	message string
}

// Error implements the error interface by returning the stored message.
func (fe *flagError) Error() string {
	msg := fe.message
	return msg
}

// FlagError builds a *flagError whose message is formatted with fmt.Sprintf
// semantics from format and its arguments.
func FlagError(format string, a ...interface{}) *flagError {
	fe := new(flagError)
	fe.message = fmt.Sprintf(format, a...)
	return fe
}

// snakeRegex matches an upper-case ASCII letter that is either the first
// character of the string or directly preceded by a lower-case ASCII letter;
// the preceding letter (if any) is included in the match.
var snakeRegex = regexp.MustCompile("(^|[a-z])([A-Z])")

// toSnake converts a CamelCase identifier to snake_case, e.g.
// "MySQLNode" -> "my_sql_node". Every "SQL" is first folded to "Sql" so the
// acronym is treated as one word. Other acronyms are not special-cased: only
// upper-case letters at the start or after a lower-case letter are rewritten.
func toSnake(s string) string {
	normalized := strings.Replace(s, "SQL", "Sql", -1)
	return snakeRegex.ReplaceAllStringFunc(normalized, func(m string) string {
		runes := []rune(m)
		if len(runes) == 2 {
			// "aB" -> "a_b": keep the lower-case letter, then separator.
			return string(runes[0]) + "_" + string(unicode.ToLower(runes[1]))
		}
		// Leading upper-case letter: just lower it.
		return string(unicode.ToLower(runes[0]))
	})
}

// toPrivate returns the unexported counterpart of an exported identifier by
// lower-casing its first character, e.g. "Marker" -> "marker". A leading
// "SQL" acronym is lower-cased as a whole ("SQLStmt" -> "sqlStmt").
// An empty string is returned unchanged instead of panicking.
func toPrivate(s string) string {
	if s == "" {
		return s
	}
	if strings.HasPrefix(s, "SQL") {
		return "sql" + s[len("SQL"):]
	}
	r := []rune(s)
	r[0] = unicode.ToLower(r[0])
	return string(r)
}

// run parses the command-line flags, renders the marker interface source and
// writes it to "<output>_gen.go" in the current directory. Flag validation
// failures are returned as *flagError so that main prints the usage text.
//
// Flags:
//
//	-t   exported marker interface name (required)
//	-o   output file base name; "_gen.go" is appended (default: snake_case of -t)
//	-e   comma-separated list of types to embed in the interface
//	-pkg package name of the generated file (default: $GOPACKAGE)
func run() error {
	var flags struct {
		MarkerTypeName string
		OutputName     string
		Embedded       string
		Package        string
	}

	flag.StringVar(&flags.MarkerTypeName, "t", "", "marker interface type name (required)")
	flag.StringVar(&flags.OutputName, "o", "", "output filename (automatically add '_gen.go')")
	flag.StringVar(&flags.Embedded, "e", "", "embedded struct list (comma separated)")
	flag.StringVar(&flags.Package, "pkg", os.Getenv("GOPACKAGE"), "package name")
	flag.Parse()

	markerTypeName := flags.MarkerTypeName
	if markerTypeName == "" {
		return FlagError("-t is must be required")
	}
	if !unicode.IsUpper([]rune(markerTypeName)[0]) {
		return FlagError("-t is must be public")
	}
	implTypeName := toPrivate(markerTypeName)

	// $GOPACKAGE is only set when invoked through `go generate`. Without it
	// and without -pkg the output would have an empty package clause and
	// fail later with an opaque formatting error.
	if flags.Package == "" {
		return FlagError("-pkg must be specified when $GOPACKAGE is not set")
	}

	outputName := flags.OutputName
	if outputName == "" {
		outputName = toSnake(markerTypeName)
	}
	outputName += "_gen.go"

	// Tolerate spaces around commas as well as empty/trailing entries in -e.
	var embedded []string
	for _, e := range strings.Split(flags.Embedded, ",") {
		if e = strings.TrimSpace(e); e != "" {
			embedded = append(embedded, e)
		}
	}

	src, err := generateSource(flags.Package, markerTypeName, implTypeName, embedded)
	if err != nil {
		return err
	}

	err = ioutil.WriteFile(outputName, src, 0666)
	if err != nil {
		return fmt.Errorf("failed to write generate code: %w", err)
	}
	return nil
}

// generateSource renders the gofmt-formatted source of a file in package pkg
// declaring the marker interface markerTypeName, the unexported struct
// implTypeName that implements it, and the interface's embedded types.
func generateSource(pkg, markerTypeName, implTypeName string, embedded []string) ([]byte, error) {
	var buf bytes.Buffer
	// The "Code generated ... DO NOT EDIT." line must appear before the
	// package clause, otherwise tools do not recognize the file as generated.
	buf.WriteString("// Code generated by genmark. DO NOT EDIT.\n\n")
	fmt.Fprintf(&buf, "package %s\n\n", pkg)
	fmt.Fprintf(&buf, "type %s interface {\n", markerTypeName)
	fmt.Fprintf(&buf, "%sMarker()\n", implTypeName)
	for _, e := range embedded {
		fmt.Fprintln(&buf, e)
	}
	buf.WriteString("}\n\n")
	fmt.Fprintf(&buf, "type %s struct{}\n\n", implTypeName)
	fmt.Fprintf(&buf, "func (%s) %sMarker() {}\n", implTypeName, implTypeName)

	src, err := format.Source(buf.Bytes())
	if err != nil {
		return nil, fmt.Errorf("failed to format source code: %w", err)
	}
	return src, nil
}
